#!/usr/bin/env python3
#
# Cross Platform and Multi Architecture Advanced Binary Emulation Framework
#

"""
This module is intended for general purpose functions that are only used in qiling.os
"""

from typing import Callable, Iterable, Iterator, List, MutableMapping, Sequence, Tuple, TypeVar, Union
from uuid import UUID

from qiling import Qiling
from qiling.const import QL_VERBOSE

# TODO: separate windows-specific implementation
from qiling.os.windows.structs import make_unicode_string


class QlOsUtils:
    """Helpers used across qiling.os: reading strings and GUIDs from emulated memory,
    logging hooked function invocations and emulating printf-family formatting.
    """

    # name prefix given to parameters synthesized for variadic arguments: __update_ellipsis
    # generates such names and print_function hides them when logging
    ELLIPSIS_PREF = r'__qlva_'

    def __init__(self, ql: Qiling):
        self.ql = ql

    def read_string(self, address: int, encoding: str, maxlen: int = 0) -> str:
        """Read a null-terminated string from memory.

        Characters are read one at a time, each as wide as the encoded null terminator
        (e.g. 1 byte for latin1, 2 bytes for utf-16le), so the encoding is assumed to be
        fixed-width. The decoded string is also recorded in the OS string statistics.

        Args:
            address : starting address
            encoding: string encoding to use
            maxlen  : limit number of characters to read before reaching null terminator,
                      0 for unlimited length

        Returns: decoded string; undecodable bytes are backslash-escaped
        """

        terminator = '\x00'.encode(encoding)

        data = bytearray()
        charlen = len(terminator)
        strlen = 0

        while True:
            char = self.ql.mem.read(address, charlen)

            if char == terminator:
                break

            data += char
            strlen += 1

            if strlen == maxlen:
                break

            address += charlen

        s = data.decode(encoding, errors='backslashreplace')
        self.ql.os.stats.log_string(s)

        return s

    def read_wstring(self, address: int, maxlen: int = 0) -> str:
        """Read a null-terminated wide (utf-16le) string from memory.
        """

        return self.read_string(address, 'utf-16le', maxlen)

    def read_cstring(self, address: int, maxlen: int = 0) -> str:
        """Read a null-terminated single-byte string from memory (decoded as latin1).
        """

        return self.read_string(address, 'latin1', maxlen)

    def read_guid(self, address: int) -> UUID:
        """Read a 16-byte GUID from memory, laid out with little-endian fields (bytes_le).
        """

        raw_guid = self.ql.mem.read(address, 16)

        return UUID(bytes_le=bytes(raw_guid))

    @staticmethod
    def stringify(s: str) -> str:
        """Decorate a string with quotation marks, escaping it the way repr() does.
        """

        return f'"{repr(s)[1:-1]}"'

    def print_function(self, address: int, fname: str, pargs: Sequence[Tuple[str, str]], ret: Union[int, str, None], passthru: bool):
        '''Print out function invocation details.

        The entry is logged at debug level, prefixed with the function address, when verbosity
        is DEBUG or higher; otherwise it is logged at info level without the address.

        Args:
            address: function address
            fname: function name; a leading 'hook_' prefix is dropped
            pargs: processed args list: a sequence of 2-tuples consisting of arg names paired to string representation of arg values
            ret: function return value, or None if no such value
            passthru: whether this is a passthrough invocation (no frame unwinding)
        '''

        if fname.startswith('hook_'):
            fname = fname[5:]

        def __assign_arg(name: str, value: str) -> str:
            # ignore arg names generated by variadic functions
            if name.startswith(QlOsUtils.ELLIPSIS_PREF):
                name = ''

            return f'{name} = {value}' if name else f'{value}'

        # arguments list
        fargs = ', '.join(__assign_arg(name, value) for name, value in pargs)

        if type(ret) is int:
            ret = f'{ret:#x}'

        # optional prefixes and suffixes
        fret = f' = {ret}' if ret is not None else ''
        fpass = ' (PASSTHRU)' if passthru else ''
        faddr = f'{address:#0{self.ql.arch.bits // 4 + 2}x}: ' if self.ql.verbose >= QL_VERBOSE.DEBUG else ''

        log = f'{faddr}{fname}({fargs}){fret}{fpass}'

        if self.ql.verbose >= QL_VERBOSE.DEBUG:
            self.ql.log.debug(log)
        else:
            self.ql.log.info(log)

    def __common_printf(self, format: str, va_args: Iterator[int], wstring: bool) -> Tuple[str, Sequence[int]]:
        """Render a printf-style format string, pulling argument values from va_args on demand.

        Conversion specifications follow the MSVC printf syntax. String conversions (s, S, Z)
        dereference their argument and read the string from emulated memory. Every
        specification is formatted on its own, so text that does not form a valid
        specification (e.g. a lone trailing '%') is copied to the output verbatim.

        Args:
            format : format string
            va_args: raw argument values; only as many as the format requires are consumed
            wstring: whether the calling function is a wide-character one (e.g. wprintf);
                     this sets the default width of string and character arguments

        Returns: a 2-tuple of the formatted string and the raw argument values consumed,
                 in consumption order
        """

        import re

        # https://docs.microsoft.com/en-us/cpp/c-runtime-library/format-specification-syntax-printf-and-wprintf-functions
        # %[flags][width][.precision][size]type
        #
        # note: the precision dot must be escaped, otherwise it matches any character; digits
        # after the dot are optional since a bare '.' stands for a precision of zero
        fmtstr = re.compile(r'''%
            (?P<follows>%|
                (?P<flags>[-+0 #]+)?
                (?P<width>[*]|[0-9]+)?
                (?:\.(?P<precision>[*]|[0-9]*))?
                (?P<size>hh|ll|I32|I64|[hjltwzIL])?
                (?P<type>[diopuaAcCeEfFgGsSxXZ])
            )
        ''', re.VERBOSE)

        T = TypeVar('T')

        def __dup(iterator: Iterator[T], out: List[T]) -> Iterator[T]:
            """A wrapper iterator to record iterator elements as they are being yielded.
            """

            for elem in iterator:
                out.append(elem)
                yield elem

        orig_args = []  # original (raw) arguments, in the order they were consumed

        va_list = __dup(va_args, orig_args)

        def __c_int(value: int) -> int:
            """Reinterpret a raw argument slot as a C int (32 bits, two's complement).
            """

            value &= 0xffffffff

            return value - (1 << 32) if value & 0x80000000 else value

        def __is_wide(typ: str, size: str) -> bool:
            """Determine whether a string or character argument is wide, per MSVC rules:
            an 'l' or 'w' size prefix means wide and 'h' means narrow; otherwise 's' and 'c'
            follow the calling function's width, 'S' and 'C' take the opposite one, and
            'Z' refers to an ANSI_STRING.
            """

            if size in ('l', 'w'):
                return True

            if size == 'h':
                return False

            if typ in 'SC':
                return not wstring

            if typ == 'Z':
                return False

            return wstring

        def __read_str(address: int, wide: bool, maxlen: int = 0) -> str:
            """Read a string argument from memory, rendering null pointers as '(null)'
            instead of attempting to read memory at address 0.
            """

            if address == 0:
                return '(null)'

            reader = self.read_wstring if wide else self.read_cstring

            return reader(address, maxlen)

        def __repl(m: re.Match) -> str:
            """Format a single conversion specification, consuming the arguments it needs.
            """

            if m['follows'] == '%':
                return '%'

            flags = m['flags'] or ''

            fill  = ' ' if ' ' in flags else ''
            align = '-' if '-' in flags else ''     # Python's left-align flag is '-', as in C
            sign  = '+' if '+' in flags else ''
            pound = '#' if '#' in flags else ''
            zeros = '0' if '0' in flags else ''

            width = m['width'] or ''

            if width == '*':
                w = __c_int(next(va_list))

                # a negative width argument means left-alignment
                if w < 0:
                    align = '-'
                    w = -w

                width = f'{w}'

            # None when no '.' appears at all, '' when '.' is not followed by digits
            prec = m['precision']

            if prec == '*':
                p = __c_int(next(va_list))

                # a negative precision argument is taken as if precision were omitted
                prec = f'{p}' if p >= 0 else None

            prec = '' if prec is None else f'.{prec or 0}'

            typ = m['type']
            size = m['size'] or ''
            arg = next(va_list)

            if typ in 'sSZ':
                wide = __is_wide(typ, size)

                if typ == 'Z':
                    if arg == 0:
                        arg = '(null)'

                    else:
                        # note: ANSI_STRING and UNICODE_STRING have identical layout. Buffer is
                        # not guaranteed to be null-terminated, so the read is bounded by Length,
                        # which counts bytes
                        ucstr_struct = make_unicode_string(self.ql.arch.bits)

                        with ucstr_struct.ref(self.ql.mem, arg) as ucstr_obj:
                            nchars = ucstr_obj.Length // (2 if wide else 1)
                            arg = __read_str(ucstr_obj.Buffer, wide, nchars) if nchars else ''

                else:
                    arg = __read_str(arg, wide)

                typ = 's'

            elif typ in 'cC':
                # narrow the raw slot to the character type; Python has no 'C' conversion
                arg &= 0xffff if __is_wide(typ, size) else 0xff
                typ = 'c'

            elif typ == 'p':
                pound = '#'
                typ = 'x'

            # C ignores the '0' flag for integer conversions that specify a precision
            if prec and typ in 'diouxX':
                zeros = ''

            spec = f'%{fill}{align}{sign}{pound}{zeros}{width}{prec}{typ}'

            return spec % (arg,)

        out = fmtstr.sub(__repl, format)

        return out, orig_args

    def va_list(self, ptr: int) -> Iterator[int]:
        """Iterate over consecutive pointer-sized values in memory, starting at ptr.

        The iterator is unbounded; callers take only as many values as they need.
        """

        while True:
            yield self.ql.mem.read_ptr(ptr)

            ptr += self.ql.arch.pointersize

    def sprintf(self, buff: int, format: str, va_args: Iterator[int], wstring: bool = False) -> Tuple[int, Callable]:
        """Format a string and write it, null-terminated, to emulated memory.

        Args:
            buff   : address of the destination buffer
            format : format string
            va_args: raw argument values
            wstring: write a utf-16le string instead of a utf-8 one

        Returns: a 2-tuple of the number of characters written (excluding the terminator),
                 counted in code units of the output encoding, and a callable that records the
                 consumed variadic arguments into a params mapping
        """

        out, args = self.__common_printf(format, va_args, wstring)

        enc, unit = ('utf-16le', 2) if wstring else ('utf-8', 1)
        data = out.encode(enc)

        self.ql.mem.write(buff, data + '\x00'.encode(enc))

        # count what was actually written: non-ASCII text may take several units per character
        return len(data) // unit, self.__update_ellipsis(args)

    def printf(self, format: str, va_args: Iterator[int], wstring: bool = False) -> Tuple[int, Callable]:
        """Format a string and write it, utf-8 encoded, to the emulated OS stdout.

        Args:
            format : format string
            va_args: raw argument values
            wstring: whether the format belongs to a wide-character function

        Returns: a 2-tuple of the formatted string length in characters and a callable that
                 records the consumed variadic arguments into a params mapping
        """

        out, args = self.__common_printf(format, va_args, wstring)
        enc = 'utf-8'

        self.ql.os.stdout.write(out.encode(enc))

        return len(out), self.__update_ellipsis(args)

    def __update_ellipsis(self, args: Iterable[int]) -> Callable[[MutableMapping], None]:
        """Build a callable that adds the given variadic argument values to a params mapping,
        keyed by ELLIPSIS_PREF followed by the argument index.
        """

        def __do_update(params: MutableMapping) -> None:
            params.update((f'{QlOsUtils.ELLIPSIS_PREF}{i}', a) for i, a in enumerate(args))

        return __do_update